{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Transformer_translate.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Styleoshin/tensor2tensor/blob/Transformer_tutorial/tensor2tensor/notebooks/Transformer_translate.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e7PMze9tKHX9",
        "colab_type": "text"
      },
      "source": [
        "# Welcome to the [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) Colab\n",
        "\n",
        "Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and [accelerate ML research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). This notebook walks through the steps needed to use the library for a translation task: defining a problem, generating the data, training the model, measuring its quality, translating sequences, and visualizing the attention. It also shows how to download a pre-trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KC8jNpnyKJdm",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title\n",
        "# Copyright 2018 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AYUy570fKRcw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install deps\n",
        "!pip install -q -U tensor2tensor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEhFfyVNbB_D",
        "colab_type": "text"
      },
      "source": [
        "# 1. Initialization\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i23pCAVwegx3",
        "colab_type": "text"
      },
      "source": [
        "## 1.1. Make some directories"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oUf4e18_8E31",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import sys\n",
        "if 'google.colab' in sys.modules: # Colab-only TensorFlow version selector\n",
        "  %tensorflow_version 1.x\n",
        "import tensorflow as tf\n",
        "import os\n",
        "\n",
        "DATA_DIR = os.path.expanduser(\"/t2t/data\") # This folder contains the data\n",
        "TMP_DIR = os.path.expanduser(\"/t2t/tmp\") # This folder contains the raw downloaded files\n",
        "TRAIN_DIR = os.path.expanduser(\"/t2t/train\") # This folder contains the model\n",
        "EXPORT_DIR = os.path.expanduser(\"/t2t/export\") # This folder contains the model exported for production\n",
        "TRANSLATIONS_DIR = os.path.expanduser(\"/t2t/translation\") # This folder contains all translated sequences\n",
        "EVENT_DIR = os.path.expanduser(\"/t2t/event\") # This folder contains the event files written when testing the BLEU score\n",
        "USR_DIR = os.path.expanduser(\"/t2t/user\") # This folder contains our own data and problems that we want to add\n",
        " \n",
        "tf.gfile.MakeDirs(DATA_DIR)\n",
        "tf.gfile.MakeDirs(TMP_DIR)\n",
        "tf.gfile.MakeDirs(TRAIN_DIR)\n",
        "tf.gfile.MakeDirs(EXPORT_DIR)\n",
        "tf.gfile.MakeDirs(TRANSLATIONS_DIR)\n",
        "tf.gfile.MakeDirs(EVENT_DIR)\n",
        "tf.gfile.MakeDirs(USR_DIR)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HIuzsMzgbLv9",
        "colab_type": "text"
      },
      "source": [
        "## 1.2. Init parameters\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZQaURmfKBGus",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "PROBLEM = \"translate_enfr_wmt32k\" # English-to-French translation problem with a vocabulary of about 32,768 subwords\n",
        "MODEL = \"transformer\" # Our model\n",
        "HPARAMS = \"transformer_big\" # Default hyperparameter set for the model\n",
        "                            # If you have only one GPU, use transformer_big_single_gpu"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EikK-hW5m-ax",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Show all problems and models\n",
        "\n",
        "from tensor2tensor.utils import registry\n",
        "from tensor2tensor import problems\n",
        "\n",
        "print(problems.available()) # Show all problems\n",
        "print(registry.list_models()) # Show all registered models\n",
        "\n",
        "# or\n",
        "\n",
        "# Command line\n",
        "!t2t-trainer --registry_help # Show the contents of the registry (models, hparams sets, problems)\n",
        "!t2t-trainer --problems_help # Show all problems"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "78kBAIMQbeO6",
        "colab_type": "text"
      },
      "source": [
        "# 2. Data generation \n",
        "\n",
        "Download the dataset and generate the training and evaluation data from it.\n",
        "\n",
        "---\n",
        "\n",
        " You can choose between the command line and code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CrDy3V7ibpQH",
        "colab_type": "text"
      },
      "source": [
        "## 2.1. Generate with terminal\n",
        "For more information: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_datagen.py"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0Dfr8nFXmg1o",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!t2t-datagen \\\n",
        "  --data_dir=$DATA_DIR \\\n",
        "  --tmp_dir=$TMP_DIR \\\n",
        "  --problem=$PROBLEM \\\n",
        "  --t2t_usr_dir=$USR_DIR"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMvCiiBtbuzh",
        "colab_type": "text"
      },
      "source": [
        "## 2.2. Generate with code"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Of5bHYVJmbwH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "t2t_problem = problems.problem(PROBLEM)\n",
        "t2t_problem.generate_data(DATA_DIR, TMP_DIR) "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UkSwoqBzb47T",
        "colab_type": "text"
      },
      "source": [
        "# 3. Train the model\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1JVF2PJn7ByQ",
        "colab_type": "text"
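      },
      "source": [
        "## 3.1. Init parameters\n",
        "\n",
        "You can choose between the command line and code.\n",
        "\n",
        "---\n",
        "\n",
        " batch_size: for text problems in T2T, this is the approximate number of subword tokens per batch per GPU, not the number of sentences. Prefer the largest value that fits in GPU memory.\n",
        "\n",
        "---\n",
        "train_steps: the research paper mentions 300k steps with 8 GPUs for the big transformer. With 1 GPU, you will therefore need roughly 8 times as many steps to process the same amount of data (see https://arxiv.org/abs/1706.03762 for more information).\n",
        "\n"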
      ]
    },
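    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A rough sketch of the scaling arithmetic above, assuming the work done per step scales linearly with the number of GPUs (an approximation). The names `paper_steps`, `paper_gpus` and `my_gpus` are illustrative and are not used elsewhere in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "paper_steps = 300000  # Steps mentioned for the big transformer\n",
        "paper_gpus = 8        # GPUs used in that setup\n",
        "my_gpus = 1           # Assumption: number of GPUs available to you\n",
        "\n",
        "# With fewer GPUs, each step processes less data, so more steps are needed\n",
        "equivalent_steps = paper_steps * paper_gpus // my_gpus\n",
        "print(equivalent_steps)  # 2400000"
      ],
      "execution_count": 0,
      "outputs": []
    },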
    {
      "cell_type": "code",
      "metadata": {
        "id": "yw6HgVWA7AQF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_steps = 300000 # Total number of training steps\n",
        "eval_steps = 100 # Number of steps to perform for each evaluation\n",
        "batch_size = 4096\n",
        "save_checkpoints_steps = 1000 # Save a checkpoint every 1000 steps\n",
        "ALPHA = 0.1 # Used below as the learning rate (3.3) and as the decoding alpha (5.1 and 7)\n",
        "schedule = \"continuous_train_and_eval\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ze_YvVnIfD8z",
        "colab_type": "text"
      },
      "source": [
        "You can choose the schedule:\n",
        " \n",
        "\n",
        "*  train: training only, without evaluation (not recommended)\n",
        "*  continuous_train_and_eval (default)\n",
        "*  train_and_eval\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-zAub7Ggb8tj",
        "colab_type": "text"
      },
      "source": [
        "## 3.2. Train with terminal\n",
        "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_trainer.py\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kSYAi4BsnpSD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!t2t-trainer \\\n",
        "  --data_dir=$DATA_DIR \\\n",
        "  --problem=$PROBLEM \\\n",
        "  --model=$MODEL \\\n",
        "  --hparams_set=$HPARAMS \\\n",
        "  --hparams=\"batch_size=$batch_size\" \\\n",
        "  --schedule=$schedule \\\n",
        "  --output_dir=$TRAIN_DIR \\\n",
        "  --train_steps=$train_steps \\\n",
        "  --worker_gpu=1 \\\n",
        "  --eval_steps=$eval_steps"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bNfNBWtNVMwO",
        "colab_type": "text"
      },
      "source": [
        "  `--worker_gpu=1` trains on 1 GPU (optional).\n",
        "\n",
        "---\n",
        "\n",
        "For distributed training see: https://github.com/tensorflow/tensor2tensor/blob/master/docs/distributed_training.md\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nnSoC1AUcLG6",
        "colab_type": "text"
      },
      "source": [
        "## 3.3. Train with code\n",
        "create_hparams : https://github.com/tensorflow/tensor2tensor/blob/28adf2690c551ef0f570d41bef2019d9c502ec7e/tensor2tensor/utils/hparams_lib.py#L42\n",
        "\n",
        "---\n",
        "Change hyperparameters:\n",
        "https://github.com/tensorflow/tensor2tensor/blob/28adf2690c551ef0f570d41bef2019d9c502ec7e/tensor2tensor/models/transformer.py#L1627\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RJ91vQ2hyIPx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment\n",
        "from tensor2tensor.utils.trainer_lib import create_hparams\n",
        "from tensor2tensor.utils import registry\n",
        "from tensor2tensor import models\n",
        "from tensor2tensor import problems\n",
        "\n",
        "# Init the Hparams object from the hparams set\n",
        "hparams = create_hparams(HPARAMS)\n",
        "\n",
        "# Make changes to Hparams\n",
        "hparams.batch_size = batch_size\n",
        "hparams.learning_rate = ALPHA\n",
        "#hparams.max_length = 256\n",
        "\n",
        "# You can see all Hparams with the code below\n",
        "#import json\n",
        "#print(json.loads(hparams.to_json()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KZX1cwK3TEXs",
        "colab_type": "text"
      },
      "source": [
        "create_run_config : https://github.com/tensorflow/tensor2tensor/blob/28adf2690c551ef0f570d41bef2019d9c502ec7e/tensor2tensor/utils/trainer_lib.py#L105\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "create_experiment : https://github.com/tensorflow/tensor2tensor/blob/28adf2690c551ef0f570d41bef2019d9c502ec7e/tensor2tensor/utils/trainer_lib.py#L611"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yByKcs7XvAXL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "RUN_CONFIG = create_run_config(\n",
        "      model_dir=TRAIN_DIR,\n",
        "      model_name=MODEL,\n",
        "      save_checkpoints_steps= save_checkpoints_steps\n",
        ")\n",
        "\n",
        "tensorflow_exp_fn = create_experiment(\n",
        "        run_config=RUN_CONFIG,\n",
        "        hparams=hparams,\n",
        "        model_name=MODEL,\n",
        "        problem_name=PROBLEM,\n",
        "        data_dir=DATA_DIR, \n",
        "        train_steps=train_steps, \n",
        "        eval_steps=eval_steps, \n",
        "        #use_xla=True # For acceleration\n",
        "    ) \n",
        "\n",
        "tensorflow_exp_fn.train_and_evaluate()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "03xuR70jce_2",
        "colab_type": "text"
      },
      "source": [
        "# 4. See the BLEU score"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MiwyVWPhhGrk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Init files for translation\n",
        "\n",
        "SOURCE_TEST_TRANSLATE_DIR = TMP_DIR+\"/dev/newstest2014-fren-src.en.sgm\" # English source file\n",
        "REFERENCE_TEST_TRANSLATE_DIR = TMP_DIR+\"/dev/newstest2014-fren-ref.fr.sgm\" # French reference file\n",
        "BEAM_SIZE=1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "agnSg_89cr63",
        "colab_type": "text"
      },
      "source": [
        "## 4.1. Translate all\n",
        "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_translate_all.py"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jrt5fwqsg3pl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!t2t-translate-all \\\n",
        "  --source=$SOURCE_TEST_TRANSLATE_DIR \\\n",
        "  --model_dir=$TRAIN_DIR \\\n",
        "  --translations_dir=$TRANSLATIONS_DIR \\\n",
        "  --data_dir=$DATA_DIR \\\n",
        "  --problem=$PROBLEM \\\n",
        "  --hparams_set=$HPARAMS \\\n",
        "  --output_dir=$TRAIN_DIR \\\n",
        "  --t2t_usr_dir=$USR_DIR \\\n",
        "  --beam_size=$BEAM_SIZE \\\n",
        "  --model=$MODEL"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O-pKKU2Acv8Q",
        "colab_type": "text"
      },
      "source": [
        "## 4.2. Test the BLEU score\n",
        "The BLEU score for all translations: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_bleu.py#L68\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EULP9TdPc58d",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!t2t-bleu \\\n",
        "   --translations_dir=$TRANSLATIONS_DIR \\\n",
        "   --model_dir=$TRAIN_DIR \\\n",
        "   --data_dir=$DATA_DIR \\\n",
        "   --problem=$PROBLEM \\\n",
        "   --hparams_set=$HPARAMS \\\n",
        "   --source=$SOURCE_TEST_TRANSLATE_DIR \\\n",
        "   --reference=$REFERENCE_TEST_TRANSLATE_DIR \\\n",
        "   --event_dir=$EVENT_DIR"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13j50bpAc-bM",
        "colab_type": "text"
      },
      "source": [
        "# 5. Prediction of a sentence\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8WHPnqxhdQl6",
        "colab_type": "text"
      },
      "source": [
        "## 5.1. Predict with terminal\n",
        "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_decoder.py"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3SD-XhImnwpo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!echo \"the business of the house\" > \"inputs.en\"\n",
        "!echo -e \"les affaires de la maison\" > \"reference.fr\" # You can add other references\n",
        "\n",
        "!t2t-decoder \\\n",
        "  --data_dir=$DATA_DIR \\\n",
        "  --problem=$PROBLEM \\\n",
        "  --model=$MODEL \\\n",
        "  --hparams_set=$HPARAMS \\\n",
        "  --output_dir=$TRAIN_DIR \\\n",
        "  --decode_hparams=\"beam_size=1,alpha=$ALPHA\" \\\n",
        "  --decode_from_file=\"inputs.en\" \\\n",
        "  --decode_to_file=\"outputs.fr\"\n",
        "\n",
        "# See the translations\n",
        "!cat outputs.fr"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sGOC25N4dWdM",
        "colab_type": "text"
      },
      "source": [
        "## 5.2. Predict with code"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S6u4QmhPIbDx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# After training the model, restart the runtime and run this cell first, then predict.\n",
        "\n",
        "tfe = tf.contrib.eager\n",
        "tfe.enable_eager_execution()\n",
        "Modes = tf.estimator.ModeKeys"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PaCkILfjz9x3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Config\n",
        "\n",
        "from tensor2tensor import models\n",
        "from tensor2tensor import problems\n",
        "from tensor2tensor.layers import common_layers\n",
        "from tensor2tensor.utils import trainer_lib\n",
        "from tensor2tensor.utils import t2t_model\n",
        "from tensor2tensor.utils import registry\n",
        "from tensor2tensor.utils import metrics\n",
        "import numpy as np\n",
        "\n",
        "enfr_problem = problems.problem(PROBLEM)\n",
        "\n",
        "# Path to the vocab file generated in DATA_DIR, used to encode inputs and decode model outputs\n",
        "vocab_name = \"vocab.translate_enfr_wmt32k.32768.subwords\"\n",
        "vocab_file = os.path.join(DATA_DIR, vocab_name)\n",
        "\n",
        "# Get the encoders from the problem\n",
        "encoders = enfr_problem.feature_encoders(DATA_DIR)\n",
        "\n",
        "ckpt_path = tf.train.latest_checkpoint(os.path.join(TRAIN_DIR))\n",
        "print(ckpt_path)\n",
        "\n",
        "def translate(inputs):\n",
        "  encoded_inputs = encode(inputs)\n",
        "  with tfe.restore_variables_on_create(ckpt_path):\n",
        "    model_output = translate_model.infer(encoded_inputs)[\"outputs\"]\n",
        "  return decode(model_output)\n",
        "\n",
        "def encode(input_str, output_str=None):\n",
        "  \"\"\"Input str to features dict, ready for inference\"\"\"\n",
        "  inputs = encoders[\"inputs\"].encode(input_str) + [1]  # add EOS id\n",
        "  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.\n",
        "  return {\"inputs\": batch_inputs}\n",
        "\n",
        "def decode(integers):\n",
        "  \"\"\"List of ints to str\"\"\"\n",
        "  integers = list(np.squeeze(integers))\n",
        "  if 1 in integers:\n",
        "    integers = integers[:integers.index(1)]\n",
        "  return encoders[\"inputs\"].decode(np.squeeze(integers))"
      ],
      "execution_count": 0,
      "outputs": []
    },
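    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `decode` helper above cuts the model output at the first EOS id. A minimal, TensorFlow-free sketch of that trimming step follows; the function name `trim_at_eos` and the token ids are made up for illustration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "EOS_ID = 1  # Same EOS id as the one used in encode() and decode() above\n",
        "\n",
        "def trim_at_eos(ids, eos_id=EOS_ID):\n",
        "  # Return the ids that come before the first EOS id\n",
        "  ids = list(ids)\n",
        "  if eos_id in ids:\n",
        "    ids = ids[:ids.index(eos_id)]\n",
        "  return ids\n",
        "\n",
        "# Hypothetical model output: three subword ids, EOS, then padding\n",
        "print(trim_at_eos([45, 7, 912, 1, 0, 0]))  # [45, 7, 912]"
      ],
      "execution_count": 0,
      "outputs": []
    },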
    {
      "cell_type": "code",
      "metadata": {
        "id": "5zE8yHLUA2He",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Predict \n",
        "\n",
        "hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
        "translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)\n",
        "\n",
        "inputs = \"the animal didn't cross the river because it was too tired\"\n",
        "ref = \"l'animal n'a pas traversé la rivière parce qu'il était trop fatigué\" # Reference translation, used only to evaluate the quality of the output\n",
        "outputs = translate(inputs)\n",
        "\n",
        "print(\"Inputs: %s\" % inputs)\n",
        "print(\"Outputs: %s\" % outputs)\n",
        "\n",
        "# Save the translation and the reference for the BLEU evaluation\n",
        "with open(\"outputs.fr\", \"w\") as output_file:\n",
        "  output_file.write(outputs)\n",
        "\n",
        "with open(\"reference.fr\", \"w\") as reference_file:\n",
        "  reference_file.write(ref)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y6jbQ6FoRsmG",
        "colab_type": "text"
      },
      "source": [
        "## 5.3. Evaluate the BLEU Score\n",
        "BLEU score for a sequence translation: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_bleu.py#L24"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "il2oevmXRrbf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!t2t-bleu \\\n",
        "    --translation=outputs.fr \\\n",
        "    --reference=reference.fr"
      ],
      "execution_count": 0,
      "outputs": []
    },
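    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To give an intuition for what the score measures, here is a simplified sketch of clipped n-gram precision, one ingredient of BLEU. It is not the `t2t-bleu` implementation (BLEU also combines several n-gram orders and applies a brevity penalty); the helper names `ngram_counts` and `clipped_precision` are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from collections import Counter\n",
        "\n",
        "def ngram_counts(tokens, n):\n",
        "  # Count every run of n consecutive tokens\n",
        "  return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))\n",
        "\n",
        "def clipped_precision(candidate, reference, n=1):\n",
        "  cand = ngram_counts(candidate.split(), n)\n",
        "  ref = ngram_counts(reference.split(), n)\n",
        "  # Each candidate n-gram is credited at most as often as it occurs in the reference\n",
        "  overlap = sum(min(count, ref[gram]) for gram, count in cand.items())\n",
        "  total = sum(cand.values())\n",
        "  return overlap / float(total) if total else 0.0\n",
        "\n",
        "reference = \"les affaires de la maison\"\n",
        "print(clipped_precision(\"les affaires de la maison\", reference))  # 1.0\n",
        "print(clipped_precision(\"le commerce de la maison\", reference))   # 0.6"
      ],
      "execution_count": 0,
      "outputs": []
    },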
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FXegHzD1I67e",
        "colab_type": "text"
      },
      "source": [
        "# 6. Attention visualization\n",
        "This section requires a sentence predicted with code (section 5.2)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISHauPT8I-3S",
        "colab_type": "text"
      },
      "source": [
        "## 6.1. Attention utils\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2RHCTrc9I55K",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensor2tensor.visualization import attention\n",
        "from tensor2tensor.data_generators import text_encoder\n",
        "\n",
        "SIZE = 35\n",
        "\n",
        "def encode_eval(input_str, output_str):\n",
        "  inputs = tf.reshape(encoders[\"inputs\"].encode(input_str) + [1], [1, -1, 1, 1])  # Make it 4D.\n",
        "  outputs = tf.reshape(encoders[\"inputs\"].encode(output_str) + [1], [1, -1, 1, 1])  # Make it 4D.\n",
        "  return {\"inputs\": inputs, \"targets\": outputs}\n",
        "\n",
        "def get_att_mats():\n",
        "  enc_atts = []\n",
        "  dec_atts = []\n",
        "  encdec_atts = []\n",
        "\n",
        "  for i in range(hparams.num_hidden_layers):\n",
        "    enc_att = translate_model.attention_weights[\n",
        "      \"transformer/body/encoder/layer_%i/self_attention/multihead_attention/dot_product_attention\" % i][0]\n",
        "    dec_att = translate_model.attention_weights[\n",
        "      \"transformer/body/decoder/layer_%i/self_attention/multihead_attention/dot_product_attention\" % i][0]\n",
        "    encdec_att = translate_model.attention_weights[\n",
        "      \"transformer/body/decoder/layer_%i/encdec_attention/multihead_attention/dot_product_attention\" % i][0]\n",
        "    enc_atts.append(resize(enc_att))\n",
        "    dec_atts.append(resize(dec_att))\n",
        "    encdec_atts.append(resize(encdec_att))\n",
        "  return enc_atts, dec_atts, encdec_atts\n",
        "\n",
        "def resize(np_mat):\n",
        "  # Keep only the first SIZE positions\n",
        "  np_mat = np_mat[:, :SIZE, :SIZE]\n",
        "  # Sum across heads\n",
        "  row_sums = np.sum(np_mat, axis=0)\n",
        "  # Normalize\n",
        "  layer_mat = np_mat / row_sums[np.newaxis, :]\n",
        "  lsh = layer_mat.shape\n",
        "  # Add extra dim for viz code to work.\n",
        "  layer_mat = np.reshape(layer_mat, (1, lsh[0], lsh[1], lsh[2]))\n",
        "  return layer_mat\n",
        "\n",
        "def to_tokens(ids):\n",
        "  ids = np.squeeze(ids)\n",
        "  subtokenizer = hparams.problem_hparams.vocabulary['targets']\n",
        "  tokens = []\n",
        "  for _id in ids:\n",
        "    if _id == 0:\n",
        "      tokens.append('<PAD>')\n",
        "    elif _id == 1:\n",
        "      tokens.append('<EOS>')\n",
        "    elif _id == -1:\n",
        "      tokens.append('<NULL>')\n",
        "    else:\n",
        "        tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))\n",
        "  return tokens\n",
        "\n",
        "def call_html():\n",
        "  import IPython\n",
        "  display(IPython.core.display.HTML('''\n",
        "        <script src=\"/static/components/requirejs/require.js\"></script>\n",
        "        <script>\n",
        "          requirejs.config({\n",
        "            paths: {\n",
        "              base: '/static/base',\n",
        "              \"d3\": \"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min\",\n",
        "              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
        "            },\n",
        "          });\n",
        "        </script>\n",
        "        '''))"
      ],
      "execution_count": 0,
      "outputs": []
    },
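    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The normalization done by `resize` can be checked on random data. This sketch assumes attention weights shaped `(heads, length, length)`, as indexed in `resize` above; the sizes and the random values are invented for illustration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "num_heads, length = 4, 6  # Hypothetical sizes\n",
        "size = 3                  # Plays the role of SIZE above\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "weights = rng.rand(num_heads, length, length)  # Fake attention weights\n",
        "\n",
        "cropped = weights[:, :size, :size]       # Keep the first `size` positions\n",
        "head_sums = np.sum(cropped, axis=0)      # Sum across heads, shape (size, size)\n",
        "normalized = cropped / head_sums[np.newaxis, :]\n",
        "layer_mat = normalized[np.newaxis, ...]  # Extra leading dim, as in resize()\n",
        "\n",
        "print(layer_mat.shape)  # (1, 4, 3, 3)\n",
        "# After normalization, the weights of all heads sum to 1 at every position\n",
        "print(np.allclose(layer_mat.sum(axis=1), 1.0))  # True"
      ],
      "execution_count": 0,
      "outputs": []
    },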
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9PGwUbJuJHJS",
        "colab_type": "text"
      },
      "source": [
        "## 6.2. Display Attention"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ijTOlrt8JI4t",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Convert inputs and outputs to subwords\n",
        "\n",
        "inp_text = to_tokens(encoders[\"inputs\"].encode(inputs))\n",
        "out_text = to_tokens(encoders[\"inputs\"].encode(outputs))\n",
        "\n",
        "hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
        "\n",
        "# Run eval to collect attention weights\n",
        "example = encode_eval(inputs, outputs)\n",
        "# ckpt_path already points to the latest checkpoint (see section 5.2)\n",
        "with tfe.restore_variables_on_create(ckpt_path):\n",
        "  translate_model.set_mode(Modes.EVAL)\n",
        "  translate_model(example)\n",
        "# Get normalized attention weights for each layer\n",
        "enc_atts, dec_atts, encdec_atts = get_att_mats()\n",
        "\n",
        "call_html()\n",
        "attention.show(inp_text, out_text, enc_atts, dec_atts, encdec_atts)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r8yAQUDZdm1p",
        "colab_type": "text"
      },
      "source": [
        "# 7. Export the model\n",
        "For more information: https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/serving"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c2yulC7J8_I9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Export the model\n",
        "!t2t-exporter \\\n",
        "  --data_dir=$DATA_DIR \\\n",
        "  --output_dir=$TRAIN_DIR \\\n",
        "  --problem=$PROBLEM \\\n",
        "  --model=$MODEL \\\n",
        "  --hparams_set=$HPARAMS \\\n",
        "  --decode_hparams=\"beam_size=1,alpha=$ALPHA\" \\\n",
        "  --export_dir=$EXPORT_DIR"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2ltjEr3JX5-e",
        "colab_type": "text"
      },
      "source": [
        "# 8. Load pretrained model from Google Storage\n",
        "We use the pretrained English-to-German (En-De) translation model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QgY3Fw261bZC",
        "colab_type": "text"
      },
      "source": [
        "## 8.1. List the content available in storage"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7P7aJClG0t8c",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "print(\"checkpoint: \")\n",
        "!gsutil ls \"gs://tensor2tensor-checkpoints\"\n",
        "\n",
        "print(\"data: \")\n",
        "!gsutil ls \"gs://tensor2tensor-data\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wP8jrR5bbu7e",
        "colab_type": "text"
      },
      "source": [
        "## 8.2. Init model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AnYU7lrazkMm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "PROBLEM_PRETRAINED = \"translate_ende_wmt32k\"\n",
        "MODEL_PRETRAINED = \"transformer\" \n",
        "HPARAMS_PRETRAINED = \"transformer_base\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DTgPvq4q1VAr",
        "colab_type": "text"
      },
      "source": [
        "## 8.3. Load content from Google Storage"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FrxOAVcyinll",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "import os\n",
        "\n",
        "\n",
        "DATA_DIR_PRETRAINED = os.path.expanduser(\"/t2t/data_pretrained\")\n",
        "CHECKPOINT_DIR_PRETRAINED = os.path.expanduser(\"/t2t/checkpoints_pretrained\")\n",
        "\n",
        "tf.gfile.MakeDirs(DATA_DIR_PRETRAINED)\n",
        "tf.gfile.MakeDirs(CHECKPOINT_DIR_PRETRAINED)\n",
        "\n",
        "\n",
        "gs_data_dir = \"gs://tensor2tensor-data/\"\n",
        "vocab_name = \"vocab.translate_ende_wmt32k.32768.subwords\"\n",
        "vocab_file = os.path.join(gs_data_dir, vocab_name)\n",
        "\n",
        "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n",
        "ckpt_name = \"transformer_ende_test\"\n",
        "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n",
        "\n",
        "TRAIN_DIR_PRETRAINED = os.path.join(CHECKPOINT_DIR_PRETRAINED, ckpt_name)\n",
        "\n",
        "!gsutil cp {vocab_file} {DATA_DIR_PRETRAINED}\n",
        "!gsutil -q cp -R {gs_ckpt} {CHECKPOINT_DIR_PRETRAINED}\n",
        "\n",
        "CHECKPOINT_NAME_PRETRAINED = tf.train.latest_checkpoint(TRAIN_DIR_PRETRAINED) # Needed to translate with code\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LP6cro9Xbygf",
        "colab_type": "text"
      },
      "source": [
        "## 8.4. Translate"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CBoNpy5HbzoF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!echo \"the business of the house\" > \"inputs.en\"\n",
        "!echo -e \"das Geschäft des Hauses\" > \"reference.de\"\n",
        "\n",
        "!t2t-decoder \\\n",
        "  --data_dir=$DATA_DIR_PRETRAINED \\\n",
        "  --problem=$PROBLEM_PRETRAINED  \\\n",
        "  --model=$MODEL_PRETRAINED  \\\n",
        "  --hparams_set=$HPARAMS_PRETRAINED \\\n",
        "  --output_dir=$TRAIN_DIR_PRETRAINED  \\\n",
        "  --decode_hparams=\"beam_size=1\" \\\n",
        "  --decode_from_file=\"inputs.en\" \\\n",
        "  --decode_to_file=\"outputs.de\"\n",
        "\n",
        "# See the translations\n",
        "!cat outputs.de\n",
        "\n",
        "!t2t-bleu \\\n",
        "    --translation=outputs.de \\\n",
        "    --reference=reference.de"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bKI4WF0DgoFd",
        "colab_type": "text"
      },
      "source": [
        "# 9. Add your dataset/problem\n",
        "To add a new dataset/problem, subclass Problem and register it with @registry.register_problem. See TranslateEnfrWmt8k for an example: \n",
        "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/translate_enfr.py\n",
        "\n",
        "---\n",
        "Adding your own components: https://github.com/tensorflow/tensor2tensor#adding-your-own-components\n",
        "\n",
        "---\n",
        "\n",
        "See this example: https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/test_data/example_usr_dir"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mB1SIrJNqy1N",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensor2tensor.data_generators import translate_enfr\n",
        "from tensor2tensor.utils import registry\n",
        "\n",
        "@registry.register_problem\n",
        "class MyTranslateEnFr(translate_enfr.TranslateEnfrWmt8k):\n",
        "\n",
        "  def generator(self, data_dir, tmp_dir, train):\n",
        "    # Your code goes here\n",
        "    pass"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
